import os
import shutil

import torch
from torch.utils.data import DataLoader, random_split

from src.my_dataset import MyDataset
from src.my_vgg import vgg16

# Select the compute device: use CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the full dataset from disk; (224, 224) is passed through to MyDataset
# (presumably the target image size -- confirm against src.my_dataset).
my_dataset = MyDataset('../dataset', '../dataset/dataset.csv', (224, 224))
num_sample = len(my_dataset)

# 80/20 train/test split. A fixed generator seed keeps the partition identical
# from run to run.
train_percent = 0.8
train_num = int(train_percent * num_sample)
test_num = num_sample - train_num
split_generator = torch.Generator().manual_seed(42)
train_set, test_set = random_split(
    my_dataset,
    [train_num, test_num],
    generator=split_generator,
)

# Prepare the datasets and DataLoaders.
# Alternative kept for reference: load pre-split train/val folders instead.
# train_set = MyDataset('./data/train', './data/train/dataset.csv', (224, 224))
# test_set = MyDataset('./data/val', './data/val/dataset.csv', (224, 224))

# Both loaders use mini-batches of 4 and reshuffle every epoch.
# NOTE(review): shuffling the test loader is not needed for evaluation.
train_dataLoader = DataLoader(train_set, batch_size=4, shuffle=True)
test_dataLoader = DataLoader(test_set, batch_size=4, shuffle=True)

# Build the model, the loss function and the optimizer.
# Earlier torchvision-based variant, kept for reference:
# module = torchvision.models.vgg16(pretrained=True)
# module.classifier[6] = nn.Linear(4096, 2)
# NOTE(review): the meaning of the arguments (True, True, 5) is defined in
# src.my_vgg -- presumably (pretrained, progress, num_classes); confirm there.
model = vgg16(True, True, 5)

# Resume from the latest checkpoint when one exists.
pth_path = "../model/new_vgg16_trained.pth"
if os.path.exists(pth_path):
    # map_location remaps the stored tensors onto the device selected for this
    # run, so a checkpoint written on a GPU machine still loads on a CPU-only
    # machine (and vice versa).
    model.load_state_dict(torch.load(pth_path, map_location=device))
model.to(device)

loss_fn = torch.nn.CrossEntropyLoss()
loss_fn.to(device)

# Plain SGD with a fixed learning rate over all model parameters.
learning_rate = 0.0001
optim = torch.optim.SGD(model.parameters(), learning_rate)

# Training configuration: number of epochs, a running count of optimizer steps
# (drives the periodic loss log) and the size of the evaluation set.
epoch = 10
total_train_step = 0
test_date_size = len(test_set)

# torch.save does not create missing parent directories, so make sure the
# checkpoint directory exists before the first epoch finishes.
epoch_ckpt_template = "../model/new_vgg16_trained_{}.pth"
os.makedirs(os.path.dirname(epoch_ckpt_template), exist_ok=True)

for i in range(epoch):
    print("-----第{}轮训练-----".format(i + 1))

    # ---- Training pass over the training loader ----
    model.train()
    for images, targets in train_dataLoader:
        images = images.to(device)
        targets = targets.to(device)
        output = model(images)
        loss_value = loss_fn(output, targets)
        optim.zero_grad()
        loss_value.backward()
        optim.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            # .item() extracts a plain Python float so the log line shows the
            # number itself rather than a tensor repr with device/grad_fn.
            print("训练次数：{}，Loss：{}".format(total_train_step, loss_value.item()))

    # ---- Evaluation pass over the test loader ----
    model.eval()
    total_loss = 0.0      # sum of per-batch losses (not averaged)
    total_accuracy = 0    # number of correctly classified samples so far
    with torch.no_grad():
        for images, targets in test_dataLoader:
            images = images.to(device)
            targets = targets.to(device)
            output = model(images)

            # Accumulate Python numbers instead of tensors so the totals do
            # not live on the device and print as plain values.
            total_loss += loss_fn(output, targets).item()
            total_accuracy += (output.argmax(1) == targets).sum().item()

    print("测试集上的loss：{}".format(total_loss))
    print("测试集上的准确率：{}".format(total_accuracy / test_date_size))

    torch.cuda.empty_cache()

    # Checkpoint rotation: write this epoch's weights first, then drop the
    # previous epoch's file, then refresh the "latest" checkpoint that is
    # loaded on start-up. Saving before deleting means a crash in between
    # never leaves the run without a checkpoint.
    current_ckpt = epoch_ckpt_template.format(i)
    previous_ckpt = epoch_ckpt_template.format(i - 1)
    torch.save(model.state_dict(), current_ckpt)
    if os.path.exists(previous_ckpt):
        os.remove(previous_ckpt)
    shutil.copy(current_ckpt, pth_path)
